{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c6c6e585-6d14-4b6a-8621-6395518b463e",
   "metadata": {
    "tags": []
   },
   "source": [
    "# EasyCV图像分类-SwinTransformer\n",
    "\n",
    "本文将介绍如何使用EasyCV快速使用[Swin Transformer](https://arxiv.org/abs/2103.14030) 进行图像分类模型的训练、推理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0512498b-a900-4969-9563-67029a1009e2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## 运行环境要求\n",
    "\n",
    "PAI-Pytorch镜像 or 原生Pytorch1.5+以上环境 GPU机器， 内存32G以上"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cafa45c-b04b-4471-b8f0-9c88ce50cc0e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 安装依赖包\n",
    "\n",
    "注: 在PAI-DSW docker中无需安装相关依赖，可跳过此部分 在本地notebook环境中执行\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8280983-39b8-414d-a712-c40fd68dadfe",
   "metadata": {},
   "source": [
    "1、 首先，安装pytorch和对应版本的torchvision，支持Pytorch1.5.1以上版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9487def6-9441-435c-9da6-633cd09c517f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# install pytorch and torch vision\n",
    "! conda install --yes pytorch==1.10.0 torchvision==0.11.0 -c pytorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b6b6a2-0106-43df-90a5-7b315e91cd64",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "2、 获取torch和cuda版本，安装对应版本的mmcv和nvidia-dali"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f93a5657-fdf1-467b-847d-e48fedb362ef",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')\n",
    "os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')\n",
    "!echo \"cuda version: $CUDA\"\n",
    "!echo \"pytorch version: $Torch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c0d7f7d-bbdc-45ac-b738-e31e00d5eef4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# install some python deps\n",
    "! pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/${CUDA}/${Torch}/index.html\n",
    "! pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl"
   ]
  },
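  {
   "cell_type": "markdown",
   "id": "mmcv-index-url-sketch",
   "metadata": {},
   "source": [
    "The `-f` index URL in the install command above is assembled from the CUDA and PyTorch versions. As a rough sketch, the same string can be built in plain Python. `mmcv_index_url` is a hypothetical helper introduced only for illustration (it is not part of EasyCV or mmcv), and the version strings passed to it are example inputs.\n",
    "\n",
    "```python\n",
    "MMCV_DIST_ROOT = 'https://download.openmmlab.com/mmcv/dist'\n",
    "\n",
    "\n",
    "def mmcv_index_url(cuda_version, torch_version):\n",
    "    # Hypothetical helper: build the wheel index URL for a CUDA / PyTorch pair.\n",
    "    cuda_tag = 'cu' + cuda_version.replace('.', '')    # '11.3' -> 'cu113'\n",
    "    torch_tag = 'torch' + torch_version.split('+')[0]  # '1.10.0+PAI' -> 'torch1.10.0'\n",
    "    return f'{MMCV_DIST_ROOT}/{cuda_tag}/{torch_tag}/index.html'\n",
    "\n",
    "\n",
    "print(mmcv_index_url('11.3', '1.10.0+PAI'))\n",
    "# https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html\n",
    "```\n",
    "\n",
    "Printing the URL before running `pip install` is a quick way to check which index page the detected versions map to; not every CUDA and PyTorch combination necessarily has prebuilt wheels."
   ]
  },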
  {
   "cell_type": "markdown",
   "id": "2a987c04-e123-42b7-9d42-240e835aa410",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "3、  安装EasyCV算法包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee7590e1-c312-444e-af84-472459174d15",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "pip install pai-easycv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df0fc82c-1f2c-4869-92e6-76d7dfb34164",
   "metadata": {
    "tags": []
   },
   "source": [
    "4、 简单验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b89751f-f0c0-4aad-aced-8f6c77a4a2aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easycv.apis import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8df2d210-4bcf-4b41-a981-b264c41e91a7",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Cifar10 分类"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fce3f492-cfe0-4121-8f7d-877b31e3438b",
   "metadata": {},
   "source": [
    "下面示例介绍如何利用[cifar10](https://www.cs.toronto.edu/~kriz/cifar.html)数据，使用ResNet50模型快速进行图像分类模型的训练评估、模型预测过程"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e769ad6-3ae0-4d5b-bf23-55673a239aa8",
   "metadata": {},
   "source": [
    "### 数据准备\n",
    "下载cifar10数据，解压到`data/cifar`目录， 目录结构如下\n",
    "\n",
    "```text\n",
    "data/cifar\n",
    "└── cifar-10-batches-py\n",
    "    ├── batches.meta\n",
    "    ├── data_batch_1\n",
    "    ├── data_batch_2\n",
    "    ├── data_batch_3\n",
    "    ├── data_batch_4\n",
    "    ├── data_batch_5\n",
    "    ├── readme.html\n",
    "    ├── read.py\n",
    "    └── test_batch\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62eeff3d-274d-4ff9-8123-8e11d97c0d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "! mkdir -p data/cifar && wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/cifar-10-python.tar.gz &&  tar -zxf cifar-10-python.tar.gz -C data/cifar/"
   ]
  },
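  {
   "cell_type": "markdown",
   "id": "cifar-batch-summary-sketch",
   "metadata": {},
   "source": [
    "To confirm that the archive was extracted correctly, it can help to open one batch file. The sketch below assumes the standard CIFAR-10 python format, in which each batch is a pickled dict with `data` and `labels` entries, and that NumPy is installed (the real batches store `data` as a NumPy array). `summarize_cifar_batch` is a hypothetical helper, not an EasyCV API.\n",
    "\n",
    "```python\n",
    "import pickle\n",
    "\n",
    "\n",
    "def summarize_cifar_batch(path):\n",
    "    # Each batch is a pickled dict; encoding='bytes' keeps its keys as bytes objects.\n",
    "    with open(path, 'rb') as f:\n",
    "        batch = pickle.load(f, encoding='bytes')\n",
    "    labels = batch[b'labels']\n",
    "    return {\n",
    "        'num_images': len(batch[b'data']),\n",
    "        'num_labels': len(labels),\n",
    "        'classes': sorted(set(labels)),\n",
    "    }\n",
    "\n",
    "\n",
    "# Example call, assuming the archive was extracted as shown above:\n",
    "# summarize_cifar_batch('data/cifar/cifar-10-batches-py/data_batch_1')\n",
    "```\n",
    "\n",
    "If the image and label counts differ, or fewer classes appear than expected, the download or extraction is worth repeating before training."
   ]
  },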
  {
   "cell_type": "markdown",
   "id": "ed28f2be-1519-4d81-9d03-42891d1e9bc7",
   "metadata": {},
   "source": [
    "### 训练模型\n",
    "下载训练配置文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8b6d19-652c-4a58-89d3-38394dbb9a20",
   "metadata": {},
   "outputs": [],
   "source": [
    "! rm -rf r50.py\n",
    "!wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/classification/cifar10/swintiny_b64_5e_jpg.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d2ed375-f5ef-475e-9dc8-3386d328b8b0",
   "metadata": {},
   "source": [
    "使用单卡gpu进行训练和验证集评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6ec6da-d125-4ebe-b301-0d8cc8881512",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "! python -m easycv.tools.train  swintiny_b64_5e_jpg.py --work_dir work_dirs/classification/cifar10/swin_tiny"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82b04e66-ef62-4d2e-963f-34015a588786",
   "metadata": {},
   "source": [
    "### 导出模型\n",
    "\n",
    "模型训练完成，使用export命令导出模型进行推理，导出的模型包含推理时所需的预处理信息、后处理信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d4f7868-88aa-428c-9852-22e3eca45a07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 查看训练产生的pt文件\n",
    "! ls  work_dirs/classification/cifar10/swin_tiny*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5dc516f-fe35-4b1f-acbe-9ec96b72cc95",
   "metadata": {},
   "source": [
    "ClsEvaluator_neck_top1_best.pth 是训练过程中产生的acc最高的pth,导出该模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67aa8262-834a-4afc-a6cd-fd882a181371",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m easycv.tools.export swintiny_b64_5e_jpg.py work_dirs/classification/cifar10/swin_tiny/ClsEvaluator_neck_top1_best.pth  work_dirs/classification/cifar10/swin_tiny/best_export.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72d78ef4-0d90-4f3a-a1d9-c532e17d065b",
   "metadata": {},
   "source": [
    "### 预测\n",
    "下载测试图片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "732fa5b4-fe0f-4a9b-ba19-d73d3ebd1736",
   "metadata": {},
   "outputs": [],
   "source": [
    "! wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/cifar10/qince_data/predict/aeroplane_s_000004.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb3b24c2-7a19-4618-8bdf-42ebe1cdfeb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "from easycv.predictors.classifier import TorchClassifier\n",
    "\n",
    "output_ckpt = 'work_dirs/classification/cifar10/swin_tiny/best_export.pth'\n",
    "tcls = TorchClassifier(output_ckpt, topk=1)\n",
    "\n",
    "img = cv2.imread('aeroplane_s_000004.png')\n",
    "# input image should be RGB order\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "output = tcls.predict([img])\n",
    "print(output)"
   ]
  },
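  {
   "cell_type": "markdown",
   "id": "bgr-to-rgb-sketch",
   "metadata": {},
   "source": [
    "OpenCV loads images in BGR channel order, which is why the cell above converts to RGB before calling `predict`. As a small illustration, assuming a NumPy image array with the channels on the last axis, the same reordering can be written as a slice. `bgr_to_rgb` is a hypothetical helper shown only to make the conversion explicit.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def bgr_to_rgb(image):\n",
    "    # Reverse the last (channel) axis: B, G, R -> R, G, B.\n",
    "    return image[..., ::-1]\n",
    "\n",
    "\n",
    "# A single pure-blue pixel in OpenCV's BGR layout.\n",
    "bgr_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)\n",
    "print(bgr_to_rgb(bgr_pixel)[0, 0].tolist())  # [0, 0, 255]\n",
    "```\n",
    "\n",
    "The slice returns a view with reversed strides rather than a copy, whereas `cv2.cvtColor` returns a new array, which is usually the safer choice when passing images to other libraries."
   ]
  },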
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e617239-b034-4760-8c55-326cf3a6e475",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
